"""Wrapper around Aviary"""
from typing import Any, Dict, List, Mapping, Optional

import requests
from pydantic import Extra, Field, root_validator

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env

# Timeout, in seconds, passed as the `timeout` argument of HTTP requests
# sent to the Aviary backend.
TIMEOUT = 60


class Aviary(LLM):
    """LLM wrapper that sends prompts to an Aviary backend.

    Aviary is a backend for hosted models. You can
    find out more about Aviary at
    http://github.com/ray-project/aviary

    No Aviary client package is required: this class talks to the
    backend directly over HTTP using ``requests``.

    To get a list of the models supported on an
    Aviary backend, follow the instructions on the web site to
    install the Aviary CLI and then use:
    `aviary models`

    You must at least specify the environment
    variable or parameter AVIARY_URL.

    You may optionally specify the environment variable
    or parameter AVIARY_TOKEN.

    Example:
        .. code-block:: python

            from langchain.llms import Aviary
            light = Aviary(aviary_url='AVIARY_URL',
                            model='amazon/LightGPT')

            result = light.predict('How do you make fried rice?')
    """

    # Name of the model to query, e.g. "amazon/LightGPT".
    model: str
    # Base URL of the Aviary backend; normalized to end with "/".
    aviary_url: str
    # Optional bearer token; excluded from the exported model data.
    aviary_token: str = Field("", exclude=True)

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve connection settings and check that the model is available.

        The URL and token are taken from ``values`` or, failing that, from
        the ``AVIARY_URL`` / ``AVIARY_TOKEN`` environment variables. The
        backend's ``models`` endpoint is then queried to confirm that the
        requested model is listed.

        Args:
            values: The raw keyword arguments passed to the constructor.

        Returns:
            ``values`` with ``aviary_url`` (ending in "/") and
            ``aviary_token`` filled in.

        Raises:
            ValueError: If no model was given, the backend could not be
                queried, or the backend does not list the requested model.
        """
        aviary_url = get_from_dict_or_env(values, "aviary_url", "AVIARY_URL")
        if not aviary_url.endswith("/"):
            aviary_url += "/"
        values["aviary_url"] = aviary_url
        aviary_token = get_from_dict_or_env(
            values, "aviary_token", "AVIARY_TOKEN", default=""
        )
        values["aviary_token"] = aviary_token

        # This validator runs before field validation (pre=True), so a
        # missing `model` is not yet reported by pydantic; fail with a
        # clear ValueError instead of a bare KeyError.
        if "model" not in values:
            raise ValueError("A `model` must be specified to use Aviary.")
        model = values["model"]

        aviary_endpoint = aviary_url + "models"
        headers = {"Authorization": f"Bearer {aviary_token}"} if aviary_token else {}
        try:
            # Bound the request so an unresponsive backend cannot hang
            # object construction forever.
            response = requests.get(
                aviary_endpoint, headers=headers, timeout=TIMEOUT
            )
            result = response.json()
        except requests.exceptions.RequestException as e:
            raise ValueError(e) from e

        # Confirm model is available
        if model not in result:
            raise ValueError(f"{aviary_url} does not support model {model}.")

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters.

        The model name is included so that instances targeting different
        models are distinguishable. The token is a credential and is
        deliberately left out, consistent with the field being declared
        with ``exclude=True``.
        """
        return {
            "model": self.model,
            "aviary_url": self.aviary_url,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "aviary"

    @property
    def headers(self) -> Dict[str, str]:
        """Return the HTTP headers for backend requests.

        Contains a bearer ``Authorization`` header when a token is set,
        otherwise it is empty.
        """
        if self.aviary_token:
            return {"Authorization": f"Bearer {self.aviary_token}"}
        return {}

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Call out to Aviary.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words; when given, the generated
                text is passed through ``enforce_stop_tokens``.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If the response is not valid JSON or does not
                contain the generated text for the model.

        Example:
            .. code-block:: python

                response = aviary("Tell me a joke.")
        """
        # The query path uses the model name with "/" replaced by "--".
        url = self.aviary_url + "query/" + self.model.replace("/", "--")
        response = requests.post(
            url,
            headers=self.headers,
            json={"prompt": prompt},
            timeout=TIMEOUT,
        )
        try:
            text = response.json()[self.model]["generated_text"]
        except requests.JSONDecodeError as e:
            raise ValueError(
                f"Error decoding JSON from {url}. Text response: {response.text}",
            ) from e
        except (KeyError, TypeError) as e:
            # Valid JSON but not the expected shape (e.g. an error payload):
            # report it the same way as the other malformed-response case.
            raise ValueError(
                f"Unexpected response from {url}. Text response: {response.text}",
            ) from e
        if stop:
            text = enforce_stop_tokens(text, stop)
        return text
